# Copyright (c) 2023 az13js
# 基金分类/classify-funds is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of
# the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#          http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF
# ANY KIND,EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

"""
K-Means算法的实现。
"""

import math
import csv
import random

class Sample(object):
    """样本，或者分类的中心点。

    假如一个二维平面的点，(1, 2)，那么 Sample([1, 2]) 可以用来代表这个点。
    """

    def __init__(self, data = [], label = ''):
        """初始化，参数data是每一个维度的数据组成的列表。"""
        self._data = data
        self._label = label

    def getData(self):
        """以列表的形式获取样本的坐标。"""
        return self._data

    def getLabel(self):
        """返回标签。"""
        return self._label

    def lengthWith(self, anotherSample):
        """返回当前样本与指定的样本的距离。"""
        anotherData = anotherSample.getData()
        total = len(anotherData)
        sumValue = 0
        for i in range(total):
            sumValue = sumValue + pow(self._data[i] - anotherData[i], 2)
        return math.sqrt(sumValue)

    def copyMyself(self):
        """复制并返回新的Sample实例。这个实例和当前实例的坐标是一样的。"""
        return Sample(self._data, self._label)

    def gauss(self, sigma = 1.0):
        """使用正态分布方法对当前实例的坐标进行随机地移动。可以用sigma控制移动的随机程度，sigma越大移动距离越大。"""
        total = len(self._data)
        for i in range(total):
            self._data[i] = random.gauss(self._data[i], sigma)

class Category(object):
    """代表样本所属的分类。

    必须指定分类的中心点，比如二维平面，分类的中心点是坐标(2, 2)，那么实例化分类需要这样做：Category(Sample([2, 2]))。
    注意传给分类的Sample实例不可以再用于样本的表示。
    """

    def __init__(self, center):
        """实例化分类类，center指定分类的中心位置。在K-Means算法初始化的时候，可以指定一个随机的中心。"""
        self._center = center
        self._samples = []

    def getCenter(self):
        """返回当前分类的中心。返回的是一个Sample实例，它的坐标是分类的中心。"""
        return self._center

    def addSample(self, sample):
        """往分类实例中添加样本，用意是这些样本属于这个分类。"""
        self._samples.append(sample)

    def cleanSample(self):
        """清空通过addSample添加的样本。"""
        self._samples = []

    def updateCenter(self):
        """通过addSample添加样本之后，此方法将通过样本计算和更新当前分类的中心。"""
        if len(self._samples) == 0:
            return
        dataLength = len(self._samples[0].getData())
        sumDist = [0 for i in range(dataLength)]

        for s in self._samples:
            data = s.getData()
            for i in range(dataLength):
                sumDist[i] = sumDist[i] + data[i]

        sampleTotal = len(self._samples)
        for i in range(dataLength):
            sumDist[i] = sumDist[i] / sampleTotal
        self._center = Sample(sumDist)

def createData(csvFileName, dataTotal = 10, centersPositions = None):
    """随机地创建样本，创建的样本将写入文件csvFileName。

    参数和参数的含义：
        csvFileName CSV文件名，会创建并向这个文件写数据。
        dataTotal 随机生成多少个样本。
        centersPositions 可选，你可以用这个参数指定样本的分类中心。
    返回值：
        无。
    """
    if centersPositions is None:
        centers = [{'x': 5, 'y': 5}, {'x': -5, 'y': -5}]
    else:
        centers = centersPositions
    sigma = 1.0
    with open(csvFileName, 'w', newline='') as fp:
        write = csv.writer(fp)
        write.writerow(['ID', 'x', 'y'])
        for i in range(dataTotal):
            center = random.choices(centers)[0]
            x = random.gauss(center['x'], sigma)
            y = random.gauss(center['y'], sigma)
            write.writerow([i + 1, x, y])

def readData(csvFileName):
    """从文件中读取数据，迭代并返回列表。

    参数和参数的含义：
        csvFileName CSV文件名
    返回值：
        通过生成器返回列表。
    """
    lineNo = 0
    with open(csvFileName, 'r') as fp:
        reader = csv.reader(fp)
        for row in reader:
            lineNo = lineNo + 1
            if 1 == lineNo:
                continue
            yield [float(row[1]), float(row[2])]

class AlgorithmLogic(object):
    """K-Means算法主体逻辑。

    你需要先创建此类的实例，然后用addSampleData添加多个样本，再用run来执行这个算法。
    """

    def __init__(self):
        """不用传什么参数，这个初始化方法只是设置几个私有的成员。"""
        self._samples = []
        self._categorys = []
        self._centerMoves = []
        self._centerMoveLength = []

    def addSampleData(self, data, label = ''):
        """添加样本点。data参数是列表，代表样本的坐标。label是字符串，一个区分数据用的标签，相当于备注。你需要保证这些样本的维数是一样的。"""
        self._samples.append(Sample(data, label))

    def getCategorys(self):
        """返回所有的分类对象。"""
        return self._categorys

    def getCenterMoveSteps(self):
        """返回所有的分类的中心在每一次迭代后移动的距离总和的变化情况。返回值是一个列表。"""
        return self._centerMoveLength

    def getCenterMoveDetails(self):
        """返回迭代时分类中心的变化情况。返回值是二维的列表，列表的每一个元素是所有的分类中心曾经的位置。"""
        return self._centerMoves

    def run(self, categoryTotal = 2, loopTotal = 50, minLength = 1E-3):
        """运行算法。categoryTotal指定分类数量，loopTotal是最大迭代次数，minLength意味着分类中心移动量总和小于这个数字的时候停止迭代。"""
        self._categorys = []
        self._centerMoves = []
        self._centerMoveLength = []

        for c in random.choices(self._samples, k = categoryTotal):
            centerPoint = c.copyMyself()
            centerPoint.gauss()
            self._categorys.append(Category(centerPoint))

        move = []
        for c in self._categorys:
            move.append(c.getCenter().getData())
        self._centerMoves.append(move)

        for i in range(loopTotal):

            for s in self._samples:
                minCate = None
                minCateLength = None
                for c in self._categorys:
                    l1 = s.lengthWith(c.getCenter())
                    if minCateLength is None or l1 < minCateLength:
                        minCate = c
                        minCateLength = l1
                minCate.addSample(s)

            for c in self._categorys:
                c.updateCenter()
                c.cleanSample()

            lastMove = self._centerMoves[-1]
            move = []
            for c in self._categorys:
                move.append(c.getCenter().getData())
            self._centerMoves.append(move)

            moveLength = 0
            dataLength = len(move[0])
            for k in range(categoryTotal):
                value = 0
                for e in range(dataLength):
                    value = value + pow(lastMove[k][e] - move[k][e], 2)
                moveLength = moveLength + math.sqrt(value)
            self._centerMoveLength.append(moveLength)
            if moveLength < minLength:
                break

    def getSilhouetteCoefficient(self):
        """返回轮廓系数"""
        categorys = []
        categorysNum = [0 for i in range(len(self._categorys))]
        lengthMatrix = [[0 for j in range(len(self._samples))] for i in range(len(self._samples))]
        for i in range(len(self._samples)):
            for j in range(len(self._samples)):
                if i == j:
                    lengthMatrix[i][j] = 0.0
                else:
                    if j > i:
                        lengthMatrix[i][j] = self._samples[i].lengthWith(self._samples[j])
                    else:
                        lengthMatrix[i][j] = lengthMatrix[j][i]

        for s in self._samples:
            minCategoryKey = None
            minLength = None
            for i in range(len(self._categorys)):
                l = s.lengthWith(self._categorys[i].getCenter())
                if minCategoryKey is None or l < minLength:
                    minCategoryKey = i
                    minLength = l
            categorys.append(minCategoryKey)
            categorysNum[minCategoryKey] = categorysNum[minCategoryKey] + 1

        # 每个样本与所有和自己同类的其它样本的距离的平均值。
        a = [0.0 for i in range(len(self._samples))]
        # 每个样本与非自己的分类的所有样本的平均值最小值。
        b = [[None for j in range(len(self._categorys))] for i in range(len(self._samples))]
        # 轮廓系数
        s = []
        for i in range(len(self._samples)):
            for j in range(len(self._samples)):
                if i == j:
                    continue
                if categorys[i] == categorys[j]:
                    a[i] = a[i] + lengthMatrix[i][j]
                else:
                    if b[i][categorys[j]] is None:
                        b[i][categorys[j]] = 0.0
                    b[i][categorys[j]] = b[i][categorys[j]] + lengthMatrix[i][j]
            a[i] = a[i] / categorysNum[categorys[i]]
            item = []
            for k in range(len(self._categorys)):
                if b[i][k] is not None:
                    item.append(b[i][k] / categorysNum[k])
            b[i] = min(item)
        s.append((b[i] - a[i]) / max(a[i], b[i]))
        sSum = 0.0
        for i in s:
            sSum = sSum + i
        return sSum / len(s)

    def getSampleCategory(self):
        """返回样本的分类"""
        categorys = []
        for s in self._samples:
            minCategoryKey = None
            minLength = None
            for i in range(len(self._categorys)):
                l = s.lengthWith(self._categorys[i].getCenter())
                if minCategoryKey is None or l < minLength:
                    minCategoryKey = i
                    minLength = l
            categorys.append(minCategoryKey)
        return categorys

if __name__ == '__main__':
    """算是示例代码，演示如何使用文件提供的算法类。"""
    fileName = 'km_test_data.csv'

    createData(fileName, 100)

    logic = AlgorithmLogic()
    for data in readData(fileName):
        logic.addSampleData(data)

    logic.run()

    print('Central location:')
    for c in logic.getCategorys():
        print(c.getCenter().getData())

    print('Silhouette Coefficient: ' + str(logic.getSilhouetteCoefficient()))

    print('Detailed path movement information:')
    with open('km_test_move.csv', 'w', newline = '') as fp:
        write = csv.writer(fp)
        loc = logic.getCenterMoveDetails()
        headers = []
        for k in loc[0]:
            headers = headers + ['x', 'y']
        write.writerow(headers)
        for m in loc:
            print(m)
            rows = []
            for c in m:
                rows = rows + c
            write.writerow(rows)

    print('The sum of the center moving distances:')
    print(logic.getCenterMoveSteps())
